{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div style='color:#ff00ff;font-weight:800;'>stable-baselines3库介绍</div>\n",
    "\n",
    "\n",
    "github：https://github.com/DLR-RM/stable-baselines3\n",
    "\n",
    "doc：https://stable-baselines3.readthedocs.io/en/master/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一、stable-baselines3库是干什么的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of Stable Baselines."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 二、为什么要用公共库\n",
    "\n",
    "简单、方便：公共库已经提供了可靠的强化学习算法实现，不必从零编写算法，只需几行代码就能完成模型的训练、保存和加载。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 三、stable-baselines3简单实例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end PPO demo on CartPole-v1: train, save, then reload the model.\n",
    "import gym\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "\n",
    "# Parallel environments: n_envs=4 requests four CartPole-v1 instances.\n",
    "env = make_vec_env(\"CartPole-v1\", n_envs=4)\n",
    "\n",
    "# Create a PPO agent using MlpPolicy and train it for 25000 timesteps.\n",
    "model = PPO(\"MlpPolicy\", env, verbose=1)\n",
    "model.learn(total_timesteps=25000)\n",
    "# Save the trained agent under the name ppo_cartpole.\n",
    "model.save(\"ppo_cartpole\")\n",
    "\n",
    "del model # drop the in-memory agent so the load below is a real save/load round trip\n",
    "\n",
    "# Restore the agent from the file saved above.\n",
    "model = PPO.load(\"ppo_cartpole\")\n",
    "\n",
    "# Initial observation(s) that the rollout loop below would start from.\n",
    "obs = env.reset()\n",
    "# The rollout loop is left commented out: it is a `while True` with no exit,\n",
    "# so enabling it would keep this cell running until it is interrupted.\n",
    "# while True:\n",
    "#     action, _states = model.predict(obs)\n",
    "#     obs, rewards, dones, info = env.step(action)\n",
    "#     env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a previously saved agent: rebuild the environments and load the model\n",
    "# without training again.\n",
    "import gym\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "\n",
    "# Parallel environments: n_envs=4 requests four CartPole-v1 instances.\n",
    "env = make_vec_env(\"CartPole-v1\", n_envs=4)\n",
    "\n",
    "# Loads the model saved as ppo_cartpole; that file is written by the training\n",
    "# cell earlier in this notebook, so it has to exist before this cell is run.\n",
    "model = PPO.load(\"ppo_cartpole\")\n",
    "\n",
    "# Initial observation(s) that the rollout loop below would start from.\n",
    "obs = env.reset()\n",
    "# The rollout loop is left commented out: it is a `while True` with no exit,\n",
    "# so enabling it would keep this cell running until it is interrupted.\n",
    "# while True:\n",
    "#     action, _states = model.predict(obs)\n",
    "#     obs, rewards, dones, info = env.step(action)\n",
    "#     env.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 四、没有训练的效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Untrained baseline: drive CartPole-v1 with randomly sampled actions for\n",
    "# 5000 steps and render every step.\n",
    "import gym\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "\n",
    "# Starting with done=True makes the first loop iteration reset the environment.\n",
    "done = True\n",
    "for step in range(5000):\n",
    "    if done:\n",
    "        # Previous episode finished (or this is the first step): start a new one.\n",
    "        state = env.reset()\n",
    "    # No model involved: the action is sampled from the action space.\n",
    "    state, reward, done, info = env.step(env.action_space.sample())\n",
    "    env.render()\n",
    "\n",
    "# Release the environment once the 5000 steps are done.\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 五、不使用并行环境"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPO on a single CartPole-v1 environment (no make_vec_env): train briefly,\n",
    "# save, reload, then render the reloaded agent until the cell is interrupted.\n",
    "import gym\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "\n",
    "# Only 200 timesteps are requested, so this is a very short training run.\n",
    "model = PPO(\"MlpPolicy\", env, verbose=1)\n",
    "model.learn(total_timesteps=200)\n",
    "model.save(\"ppo_cartpole1\")\n",
    "\n",
    "del model # drop the in-memory agent so the load below is a real save/load round trip\n",
    "\n",
    "model = PPO.load(\"ppo_cartpole1\")\n",
    "\n",
    "# `obs` must always be the observation of the environment's current state:\n",
    "# it is what the policy sees when choosing the next action.\n",
    "obs = env.reset()\n",
    "try:\n",
    "    while True:\n",
    "        action, _states = model.predict(obs)\n",
    "        obs, rewards, done, info = env.step(action)\n",
    "        env.render()\n",
    "        if done:\n",
    "            # Episode over: store the fresh observation in `obs` so the next\n",
    "            # prediction is not made from the finished episode's last state.\n",
    "            obs = env.reset()\n",
    "finally:\n",
    "    # The loop only ends when the cell is interrupted; always release the\n",
    "    # environment (and its render window) on the way out.\n",
    "    env.close()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9465aae7e0ab1403d672807d1a0963d86dbda2f584fbe3054c36cf78311c6c77"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
